from enum import Enum
from struct import pack, unpack
import os

# Packet header:
# u16 magic 0x239a
# u16 type
# u32 data length
# u64 serial


class PacketType(Enum):
    """Packet kinds; the member value is what travels in the u16 `type`
    field of the packet header (see Packet.write / Packet.read)."""

    PT_COMMAND = 1
    PT_DATA = 2
    PT_CONSOLE = 3
    # Packet.read requires the payload to be exactly a 4-byte status word.
    PT_STATUS = 4
    # Packet.read requires a 4-byte status word, optionally followed by data.
    PT_STATUS_DATA = 5
    PT_SETUP = 6
    PT_ABORT = 7


class PacketDirection(Enum):
    """Direction a packet travels, seen from the host side."""

    # NOTE: no trailing commas here. `NAME = 1,` would make the member's
    # value the one-element tuple (1,) instead of the integer 1.
    PT_DIR_OUT = 1  # from host to device
    PT_DIR_IN = 2  # from device to host


class Packet:
    """A single protocol packet: a 16-byte header followed by a payload.

    The header is little-endian and laid out as
    ``u16 magic (0x239a), u16 type, u32 data length, u64 serial``.

    A packet is either outgoing (built from a PacketType and serialised
    with write()) or incoming (created empty and filled in by read()).
    """

    HEADER_SIZE = 16
    PACKET_MAGIC = 0x239a
    # Serial number handed to the next outgoing packet; shared by all packets.
    next_serial = 0

    def __init__(self, type, data: bytes = None):
        """Create an outgoing or an incoming packet.

        type: a PacketType member creates an outgoing packet; any other
            value creates an empty incoming packet to be filled by read().
        data: payload of an outgoing packet (defaults to empty). Must be
            left as None for an incoming packet.

        Raises ValueError if data is given for an incoming packet.
        """
        if isinstance(type, PacketType):
            self.direction = PacketDirection.PT_DIR_OUT
            self.type = type
            self.data = data if data is not None else b''

            # Take the next serial number from the class-wide counter.
            self.serial = Packet.next_serial
            Packet.next_serial += 1
        else:
            if data is not None:
                raise ValueError(
                    'data may only be supplied for an outgoing packet')

            self.direction = PacketDirection.PT_DIR_IN

    @staticmethod
    def _read_exact(ch, size):
        """Read exactly `size` bytes from `ch`, retrying after short reads.

        Raises RuntimeError if the channel stops delivering data first.
        """
        chunk = ch.read(size)
        if chunk is not None and len(chunk) >= size:
            # Common case: hand the channel's own object back untouched.
            return chunk

        buf = bytearray()
        while True:
            if not chunk:
                raise RuntimeError(
                    'protocol error, short read: expected {} bytes, got {}'
                    .format(size, len(buf)))
            buf += chunk
            if len(buf) >= size:
                return bytes(buf)
            chunk = ch.read(size - len(buf))

    @staticmethod
    def _sync_to_magic(ch):
        """Consume bytes from `ch` until the packet magic has been read.

        The window slides one byte at a time so that the stream can be
        re-synchronised even when an odd number of bytes was lost. An 'E'
        marker is printed for every position that does not match.
        """
        window = bytes(Packet._read_exact(ch, 2))
        while unpack("<H", window)[0] != Packet.PACKET_MAGIC:
            print('\x1b[1;31mE\x1b[0m', end='', flush=True)
            window = window[1:] + bytes(Packet._read_exact(ch, 1))

    def write(self, ch):
        """Serialise this outgoing packet (header + payload) to `ch`.

        The header and payload are passed to ch.write() in a single call.

        Raises RuntimeError if this is not an outgoing packet.
        """
        if self.direction != PacketDirection.PT_DIR_OUT:
            raise RuntimeError('self.direction != PacketDirection.PT_DIR_OUT')

        header = pack("<HHIQ", Packet.PACKET_MAGIC, self.type.value,
                      len(self.data), self.serial)
        assert len(header) == Packet.HEADER_SIZE
        ch.write(header + self.data)

    def read(self, ch):
        """Read one packet from `ch` into this incoming packet.

        On success sets self.type, self.data and self.serial. For
        PT_STATUS and PT_STATUS_DATA packets the leading 4-byte
        little-endian status word is stored in self.status and removed
        from self.data.

        Raises RuntimeError if this is not an incoming packet, if the
        channel runs dry mid-packet, or on a protocol error: bad magic
        (non-Windows only), unknown packet type, or a payload length that
        is invalid for the packet type.
        """
        if self.direction != PacketDirection.PT_DIR_IN:
            raise RuntimeError('self.direction != PacketDirection.PT_DIR_IN')

        if os.name == 'nt':
            # FIXME: On Windows data tends to get lost randomly, so instead
            # of failing on a bad magic we scan forward to the next one.
            Packet._sync_to_magic(ch)
            raw_type, data_length, serial = unpack(
                "<HIQ", Packet._read_exact(ch, Packet.HEADER_SIZE - 2))
        else:
            magic, raw_type, data_length, serial = unpack(
                "<HHIQ", Packet._read_exact(ch, Packet.HEADER_SIZE))
            # Raise instead of assert: asserts are stripped under `python -O`.
            if magic != Packet.PACKET_MAGIC:
                raise RuntimeError(
                    'protocol error, bad magic 0x{:04x} '
                    '(type = {}, data length = {}, serial = {})'
                    .format(magic, raw_type, data_length, serial))

        if data_length > 0:
            data = Packet._read_exact(ch, data_length)
        else:
            data = b''

        try:
            packet_type = PacketType(raw_type)
        except ValueError:
            raise RuntimeError(
                'protocol error, unknown packet type %x' % raw_type) from None

        if packet_type == PacketType.PT_STATUS:
            # A bare status packet carries the status word and nothing else.
            if data_length != 4:
                raise RuntimeError(
                    'protocol error, data length {}'.format(data_length))

            self.status = unpack("<I", data[:4])[0]
            data = data[4:]
        elif packet_type == PacketType.PT_STATUS_DATA:
            # Status word first, the rest is the actual payload.
            if data_length < 4:
                raise RuntimeError('protocol error')

            self.status = unpack("<I", data[:4])[0]
            data = data[4:]
        elif data_length == 0:
            # Every other packet type must carry a non-empty payload.
            raise RuntimeError('protocol error')

        self.type = packet_type
        self.data = data
        self.serial = serial
